//===- PlanMemory.cpp ----Plan Buffer Memory Address ----------------------===//
//
// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//===----------------------------------------------------------------------===//

#include "bishengir/Dialect/HIVM/Transforms/PlanMemory.h"
#include "bishengir/Dialect/HACC/Utils/Utils.h"
#include "bishengir/Dialect/HIVM/Transforms/AllocToPointerCast.h"
#include "bishengir/Dialect/HIVM/Transforms/NormalizeLoopIterator.h"
#include "bishengir/Dialect/HIVM/Utils/Utils.h"
#include "bishengir/Dialect/MemRefExt/IR/MemRefExtImpl.h"
#include "bishengir/Dialect/Utils/Util.h"

#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"

#define DEBUG_TYPE "hivm-plan-memory"
#define LDBG(X) LLVM_DEBUG(llvm::dbgs() << X)

namespace mlir {
#define GEN_PASS_DEF_PLANMEMORY
#include "bishengir/Dialect/HIVM/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace hivm;
using namespace util;

namespace {
/// TODO: Obtain information from the same platform in the future
constexpr const int ubAlignSize = 32 * 8;
constexpr const int ubSpaceSize = 192 * 1024 * 8;
constexpr const int l1AlignSize = 32 * 8;
constexpr const int l1SpaceSize = 512 * 1024 * 8;
constexpr const int l0cAlignSize = 512 * 8;
constexpr const int l0cSpaceSize = 128 * 1024 * 8;

bool isReusableCastOp(hivm::VCastOp &castOp, Value output, Value input) {
  auto rank = dyn_cast<MemRefType>(output.getType()).getRank();
  if (rank > 1 || !isLastDimContiguous(output) || !isLastDimContiguous(input)) {
    // can only reuse 1d cast library
    return false;
  }
  return true;
}

bool isReusableNarrowWidth(Operation *op, Value output, Value input) {
  if (auto castOp = dyn_cast_if_present<hivm::VCastOp>(op)) {
    return isReusableCastOp(castOp, output, input);
  }
  return true;
}

/// Memory can inplace with rules as follows:
///   Scene1: Output has same width with input.
///   Scene2: Output width is smaller than input except VCastOp.
///   Scene3: VCastOp output width is smaller than input, and rank equal to 1
///          with the last dim contiguous. If the last dim is not contiguous, it
///          may be lifted to rank2.
bool isReusableOperands(Operation *op, HIVMStructuredOp &hivmOp) {
  auto output = hivmOp.getDpsInits().front();
  auto outputBitWidth =
      getElementTypeOrSelf(output.getType()).getIntOrFloatBitWidth();
  for (auto input : hivmOp.getDpsInputs()) {
    auto inputBitWidth =
        getElementTypeOrSelf(input.getType()).getIntOrFloatBitWidth();
    if (inputBitWidth == outputBitWidth)
      return true;
    if (inputBitWidth % outputBitWidth == 0)
      return isReusableNarrowWidth(op, output, input);
  }
  return false;
}

inline bool hasInlineBroadcastOrTransposeAttr(Operation *op) {
  auto hivmOp = cast<hivm::HIVMStructuredOp>(op);
  return hivmOp.existInlineBroadcastLoopDims() ||
         hivmOp.existInlineTransposeLoopDims();
}

/// ExtraBuffer can inplace with only one extra buffer generated for scalar or
/// OTF Brc. Otherwise, extra buffer may be generated by compound instructions.
inline bool isReusableExtraBuffer(Operation *op) {
  auto extraBufferOp = dyn_cast_if_present<ExtraBufferOpInterface>(op);
  if (!extraBufferOp || extraBufferOp.getExtraBuffers().empty()) {
    return true;
  }
  if (extraBufferOp.getExtraBuffers().size() > 1) {
    return false;
  }
  return extraBufferOp.shouldAllocExtraBufferForScalarOrOTFBrc();
}

bool isStaticShapeSame(Operation *op, const Value &genBuffer,
                       const Value &killBuffer) {
  Value genOperand = genBuffer;
  Value killOperand = killBuffer;
  for (auto operand : op->getOperands()) {
    if (genBuffer == utils::tracebackMemRef(operand)) {
      genOperand = operand;
    }
    if (killBuffer == utils::tracebackMemRef(operand)) {
      killOperand = operand;
    }
  }
  auto genOperandShapes = cast<ShapedType>(genOperand.getType()).getShape();
  auto killOperandShapes = cast<ShapedType>(killOperand.getType()).getShape();
  for (auto [genShape, killShape] :
       llvm::zip(genOperandShapes, killOperandShapes)) {
    if (genShape != killShape)
      return false;
  }
  return true;
}

} // namespace

void MemLivenessAnalysis::build() {
  Region &funcRegion = func_.getBody();
  Liveness live(func_);
  // Recursively obtaining IR information.
  RecursionIR(&funcRegion, live);
  // the lifetime of the buffer.
  GenerateBufferLife();
  InitializeInplacePairList();
}

bool MemLivenessAnalysis::isLocalMemPlan() const {
  return planMode == MemPlanMode::LOCAL_MEM_PLAN;
}

bool MemLivenessAnalysis::isGlobalWorkSpaceMemPlan() const {
  return planMode == MemPlanMode::GLOBAL_WORKSPACE_PLAN;
}

void MemLivenessAnalysis::RecursionIR(Region *region, Liveness live) {
  auto result = region->walk<WalkOrder::PreOrder>([&](Operation *op) {
    // recursive control flow
    if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
      RecursiveIfOp(ifOp, live);
      return WalkResult::skip();
    } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
      RecursiveForOp(forOp, live);
      return WalkResult::skip();
    }
    if (auto whileOp = dyn_cast<scf::WhileOp>(op)) {
      RecursiveWhileOp(whileOp, live);
      return WalkResult::skip();
    }

    // process operation
    auto curOpInfo = UpdateLinearOperation(op);
    auto aliasPairs = getOperationAliasInfo(op);
    if (!aliasPairs.empty() && !isa<arith::SelectOp>(op)) {
      for (auto aliasPair : aliasPairs) {
        UpdateBufferAlias(aliasPair.first, aliasPair.second);
      }
    } else if (isLocalMemPlan() && dyn_cast<memref::AllocOp>(op)) {
      if (failed(CheckLocalBufferAllocOp(op))) {
        return WalkResult::interrupt();
      }
      UpdateOpBufferInfo(op, op->getResults());
      return WalkResult::advance();
    } else if (isGlobalWorkSpaceMemPlan() &&
               dyn_cast<bishengir::memref_ext::AllocWorkspaceOp>(op)) {
      UpdateOpBufferInfo(op, op->getResults());
      return WalkResult::advance();
    } else if (auto loadOp = dyn_cast<memref::LoadOp>(op)) {
      OpKillHandle(curOpInfo, live, op->getBlock());
    } else if (auto storeOp = dyn_cast<memref::StoreOp>(op)) {
      UpdateStoreOpInfo(curOpInfo, storeOp.getMemRef(), live);
    } else if (auto dstStyleOp = dyn_cast<DestinationStyleOpInterface>(op)) {
      // Process the operation of hivm instructions as follows:
      // hivm.hir.copy ins(%0 : memref<16xf16, #hivm.address_space<gm>>)
      //              outs(%1 : memref<16xxf16, #hivm.address_space<ub>>)
      // need to handle kill buffer.
      UpdateInitAndResAlias(dstStyleOp);
      UpdateOpGenInfo(curOpInfo, llvm::to_vector(dstStyleOp.getDpsInits()));
      UpdateOpTempGenInfo(curOpInfo);
      OpKillHandle(curOpInfo, live, op->getBlock());
    } else if (auto selectOp = dyn_cast<arith::SelectOp>(op)) {
      UpdateBufferAlias(selectOp.getResult(), selectOp.getTrueValue(),
                        /*hasCond=*/true, /*isIgnoreInplace=*/true);
      UpdateBufferAlias(selectOp.getResult(), selectOp.getFalseValue(),
                        /*hasCond=*/true, /*isIgnoreInplace=*/true);
      OpKillHandle(curOpInfo, live, op->getBlock());
    } else if (auto markOp = dyn_cast<annotation::MarkOp>(op)) {
      UpdateMultiBufferInfo(markOp);
    } else if (auto conditionOp = dyn_cast<scf::ConditionOp>(op)) {
      UpdateConditionOpBufferAlias(conditionOp);
    } else if (auto condBrOp = dyn_cast<cf::CondBranchOp>(op)) {
      UpdateBranchOpAlias(condBrOp.getTrueDest(),
                          condBrOp.getTrueDestOperands());
      UpdateBranchOpAlias(condBrOp.getFalseDest(),
                          condBrOp.getFalseDestOperands());
    } else if (auto brOp = dyn_cast<cf::BranchOp>(op)) {
      UpdateBranchOpAlias(brOp.getDest(), brOp.getDestOperands());
    } else if (failed(CheckIfUnknownOpTouchBuffer(op))) {
      return WalkResult::interrupt();
    }
    return WalkResult::advance();
  });
  if (result == WalkResult::interrupt()) {
    llvm_unreachable("PlanMemory Traverse IR Failed! ");
  }
}

void MemLivenessAnalysis::UpdateInitAndResAlias(
    DestinationStyleOpInterface dstStyleOp) {
  auto results = dstStyleOp->getResults();
  if (results.empty()) {
    return;
  }
  for (auto [res, init] :
       llvm::zip(dstStyleOp->getResults(), dstStyleOp.getDpsInits())) {
    auto iter = buffer2AliasVec.find(init);
    if (iter == buffer2AliasVec.end()) {
      continue;
    }
    auto tensorType = dyn_cast_or_null<TensorType>(res.getType());
    if (tensorType) {
      UpdateBufferAlias(res, init);
    }
  }
}

OpInfo *MemLivenessAnalysis::UpdateLinearOperation(Operation *op) {
  auto opInfo = std::make_unique<OpInfo>(op, seqIndex++);
  auto curOpInfo = opInfo.get();
  linearOperation.push_back(std::move(opInfo));
  return curOpInfo;
}

void MemLivenessAnalysis::UpdateForOpBufferAlias(scf::ForOp forOp) {
  if (forOp.getResults().empty()) {
    return;
  }
  if (!forOp.getRegionIterArgs().empty()) {
    assert(forOp.getYieldedValues().size() == forOp.getRegionIterArgs().size());
    assert(forOp.getInitArgs().size() == forOp.getRegionIterArgs().size());
    for (auto [i, arg] : llvm::enumerate(forOp.getRegionIterArgs())) {
      // yielded values alias region iter args.
      UpdateBufferAlias(forOp.getYieldedValues()[i], arg);
    }
  }
  assert(forOp->getResults().size() == forOp.getYieldedValues().size());
  for (auto [i, arg] : llvm::enumerate(forOp.getYieldedValues())) {
    // forOp result values alias region iter yielded values.
    UpdateBufferAlias(forOp->getResult(i), arg);
  }
}

void MemLivenessAnalysis::RecursiveForOp(scf::ForOp forOp, Liveness live) {
  // Process the operation of ForOp as follows:
  // alloca %allocA
  // %0 = scf.for %arg4 = %c0 to %c1024 step %c128 iter_args(%arg5 = %4)->
  //      (memref<16x16x16xf16, #hivm.address_space<ub>>):
  //          def(allocA)
  //          ...
  //          scf.yield %alloc0 : memref<16xf16,#hivm.address_space<ub>>
  // need to handle kill buffer.
  auto forBeginSeq = UpdateLinearOperation(forOp.getOperation());
  UpdateOpGenInfo(forBeginSeq, GetLiveBuffersInLoop(forOp, live));
  UpdateForOpInitArgsAlias(forOp);
  RecursionIR(&forOp.getRegion(), live);
  UpdateForOpBufferAlias(forOp);
  auto forEndSeq = UpdateLinearOperation(forOp.getOperation());
  OpKillHandle(forEndSeq, live, forOp->getBlock());
}

void MemLivenessAnalysis::UpdateWhileOpInitArgsAlias(scf::WhileOp whileOp) {
  if (whileOp.getInits().empty()) {
    return;
  }
  assert(whileOp.getInits().size() == whileOp.getRegionIterArgs().size());
  for (auto [iterArg, init] :
       llvm::zip(whileOp.getRegionIterArgs(), whileOp.getInits())) {
    // init args alias region iter args.
    UpdateBufferAlias(iterArg, init);
  }
}

void MemLivenessAnalysis::UpdateWhileOpBufferAlias(scf::WhileOp whileOp) {
  if (whileOp.getResults().empty()) {
    return;
  }
  if (!whileOp.getRegionIterArgs().empty()) {
    assert(whileOp.getYieldedValues().size() ==
           whileOp.getRegionIterArgs().size());
    assert(whileOp.getInits().size() == whileOp.getRegionIterArgs().size());
    for (auto [yieldVal, iterArg] :
         llvm::zip(whileOp.getYieldedValues(), whileOp.getRegionIterArgs())) {
      // yielded values alias region iter args.
      UpdateBufferAlias(yieldVal, iterArg);
    }
  }
  assert(whileOp->getResults().size() == whileOp.getAfterArguments().size());
  for (auto [i, arg] : llvm::enumerate(whileOp.getAfterArguments())) {
    // whileOp result values alias after-region args.
    UpdateBufferAlias(whileOp->getResult(i), arg);
  }
}

void MemLivenessAnalysis::UpdateConditionOpBufferAlias(
    scf::ConditionOp conditionOp) {
  auto whileOp = dyn_cast<scf::WhileOp>(conditionOp->getParentOp());
  if (!whileOp || whileOp.getAfterArguments().empty()) {
    return;
  }
  assert(whileOp.getAfterArguments().size() == conditionOp.getArgs().size());
  for (auto [afterArg, condArg] :
       llvm::zip(whileOp.getAfterArguments(), conditionOp.getArgs())) {
    // whileOp after arguments alias condition operands.
    UpdateBufferAlias(afterArg, condArg);
  }
}

void MemLivenessAnalysis::RecursiveWhileOp(scf::WhileOp whileOp,
                                           Liveness live) {
  // Process the operation of WhileOp as follows:
  // alloca %allocA
  // %0 = scf.while (%arg0 = %init)
  //          %1 = arith.cmpi eq, %arg1, %const
  //          scf.condition(%1) %shared
  //      do (%shared)
  //          def(allocA)
  //          ...
  //          scf.yield %allocA
  // need to handle kill buffer.
  auto whileBeginSeq = UpdateLinearOperation(whileOp.getOperation());
  UpdateOpGenInfo(whileBeginSeq, GetLiveBuffersInLoop(whileOp, live));
  UpdateWhileOpInitArgsAlias(whileOp);
  RecursionIR(&whileOp.getBefore(), live);
  RecursionIR(&whileOp.getAfter(), live);
  UpdateWhileOpBufferAlias(whileOp);
  auto whileEndSeq = UpdateLinearOperation(whileOp.getOperation());
  OpKillHandle(whileEndSeq, live, whileOp->getBlock());
}

void MemLivenessAnalysis::UpdateForOpInitArgsAlias(scf::ForOp forOp) {
  if (forOp.getInitArgs().empty()) {
    return;
  }
  assert(forOp.getInitArgs().size() == forOp.getRegionIterArgs().size());
  for (auto [i, arg] : llvm::enumerate(forOp.getInitArgs())) {
    // init args alias region iter args.
    UpdateBufferAlias(forOp.getRegionIterArgs()[i], arg);
  }
}

void MemLivenessAnalysis::UpdateIfOpBufferAlias(scf::IfOp ifOp,
                                                scf::YieldOp yieldOp) {
  if (ifOp.getResults().empty()) {
    return;
  }
  assert(ifOp->getResults().size() == yieldOp->getOperands().size());
  for (auto [i, arg] : llvm::enumerate(yieldOp->getOperands())) {
    // Multiple buffers involved, requiring one-to-one correspondence.
    UpdateBufferAlias(ifOp->getResult(i), arg, /*hasCond=*/true);
  }
}

void MemLivenessAnalysis::RecursiveIfOp(scf::IfOp ifOp, Liveness live) {
  // Process the operation of IfOp as follows:
  // %0 = scf.if %cond -> (memref<16xf16, #hivm.address_space<ub>>)
  //        scf.yield %alloc0: memref<16xf16, #hivm.address_space<ub>>
  //      else:
  //        scf.yield %alloc1 : memref<16xf16, #hivm.address_space<ub>>
  (void)UpdateLinearOperation(ifOp.getOperation());
  RecursionIR(&ifOp.getThenRegion(), live);
  auto curIfElse = UpdateLinearOperation(ifOp.getOperation());
  UpdateIfOpBufferAlias(ifOp, ifOp.thenYield());

  auto curIfEnd = curIfElse;
  if (ifOp.elseBlock()) {
    RecursionIR(&ifOp.getElseRegion(), live);
    curIfEnd = UpdateLinearOperation(ifOp.getOperation());
    UpdateIfOpBufferAlias(ifOp, ifOp.elseYield());
  }
  OpKillHandle(curIfEnd, live, ifOp->getBlock());
}

void MemLivenessAnalysis::UpdateBranchOpAlias(Block *brBlock,
                                              OperandRange destOperands) {
  if (destOperands.empty()) {
    return;
  }
  assert(brBlock->getArguments().size() == destOperands.size());
  for (auto [destArg, destOperand] :
       llvm::zip(brBlock->getArguments(), destOperands)) {
    // BranchOp dest operands alias dest block arguments.
    UpdateBufferAlias(destArg, destOperand, /*hasCond=*/true);
  }
}

SmallVector<Value>
MemLivenessAnalysis::GetLiveBuffersInLoop(LoopLikeOpInterface loopOp,
                                          Liveness live) {
  SmallVector<Value> allocBeforeLoopBuffers;
  const auto *liveBlockInfo = live.getLiveness(loopOp->getBlock());
  auto currentLiveValues =
      liveBlockInfo->currentlyLiveValues(loopOp.getOperation());
  if (currentLiveValues.empty()) {
    return allocBeforeLoopBuffers;
  }
  // The gen buffer of the same operation must ensure the order of priority.
  SetVector<Value> currentLiveValuesOrder;
  for (auto buffer : currentLiveValues) {
    currentLiveValuesOrder.insert(buffer);
  }
  for (const Value &operand : currentLiveValuesOrder) {
    auto aliasBuffers = GetAliasBuffers(operand);
    aliasBuffers.insert(operand);
    for (auto Buffer : aliasBuffers) {
      auto iter = buffer2status.find(Buffer);
      if (iter != buffer2status.end())
        allocBeforeLoopBuffers.push_back(Buffer);
    }
  }
  return allocBeforeLoopBuffers;
}

void MemLivenessAnalysis::UpdateMultiBufferInfo(annotation::MarkOp markOp) {
  auto attrDict = markOp->getAttrDictionary();
  if (attrDict.empty()) {
    return;
  }
  if (!attrDict.contains(hivm::MultiBufferAttr::name)) {
    return;
  }
  auto multiBufferValAttr = attrDict.get(hivm::MultiBufferAttr::name);
  assert(isa<IntegerAttr>(multiBufferValAttr) &&
         "multi buffer value must be integer!");
  auto valAttr = cast<IntegerAttr>(multiBufferValAttr);
  if (valAttr.getInt() == 1) {
    // Num is 1, which is a singebuffer and does not require any processing.
    return;
  }
  buffer2MultiNum[markOp.getSrc()] = static_cast<uint64_t>(valAttr.getInt());
}

LogicalResult
MemLivenessAnalysis::CheckLocalBufferAllocOp(Operation *op) const {
  auto allocOp = dyn_cast<memref::AllocOp>(op);
  assert(allocOp && "must be alloc op");
  auto memorySpaceAttr = GetBufferSpaceAttr(allocOp.getResult());
  if (isLocalBuffer(memorySpaceAttr)) {
    return success();
  }
  allocOp.getOperation()->emitError("Alloc buffer not at UB space! ");
  return failure();
}

bool MemLivenessAnalysis::isSkippableOp(Operation *op) const {
  return isa<func::ReturnOp, scf::YieldOp, memref::DimOp, hivm::DebugOp,
             hivm::DCCIOp>(op);
}

LogicalResult
MemLivenessAnalysis::CheckIfUnknownOpTouchBuffer(Operation *op) const {
  if (isSkippableOp(op) || isGlobalWorkSpaceMemPlan()) {
    // This scene can be ignored.
    return success();
  }
  if (isOpTouchLocalBuffer(op)) {
    op->emitError("PlanMemory Fail : Unrecognized type of Operation touches "
                  "local buffer!");
    return failure();
  }
  return success();
}

void MemLivenessAnalysis::UpdateExtraBufferIgnoreInplace(
    const ValueRange &results) {
  // mark the temp buffer as ignoring inplace
  for (Value operand : results) {
    auto it = bufferInfos.find(operand);
    assert(it != bufferInfos.end());
    it->second.ignoreInplace = true;
  }
}

void MemLivenessAnalysis::UpdateOpTempGenInfo(OpInfo *opInfo) {
  // Handld Temp buffer Scenarios:
  SmallVector<Value, 1> resultVec;
  if (auto extraBufferOp =
          dyn_cast<ExtraBufferOpInterface>(opInfo->operation)) {
    auto extraBuffers = extraBufferOp.getExtraBuffers();
    resultVec.insert(resultVec.end(), extraBuffers.begin(), extraBuffers.end());
    UpdateExtraBufferIgnoreInplace(resultVec);
    UpdateOpGenInfo(opInfo, resultVec);
  }
}

void MemLivenessAnalysis::UpdateBuffer2AliasVec(
    const SetVector<Value> &buffers, const SetVector<Value> &aliasBuffers,
    bool hasCond) {
  for (auto buffer : buffers) {
    for (auto aliasValue : aliasBuffers) {
      auto bufferCondPair = FindBufferCondPair(buffer, aliasValue);
      if (bufferCondPair) {
        (*bufferCondPair)->second = (*bufferCondPair)->second || hasCond;
      } else {
        buffer2AliasVec[buffer].push_back(std::make_pair(aliasValue, hasCond));
      }
    }
  }
}

void MemLivenessAnalysis::UpdateBufferAlias(Value buffer, Value aliasBuffer,
                                            bool hasCond,
                                            bool isIgnoreInplace) {
  SetVector<Value> buffers = GetAliasBuffers(buffer);
  SetVector<Value> aliasBuffers = GetAliasBuffers(aliasBuffer);
  buffers.insert(buffer);
  aliasBuffers.insert(aliasBuffer);

  // update alias map info for each buffer
  // e.g. if A alias B, C alias D, now update:
  // A alias B,C,D; B alias A,C,D; and update C,D's condition
  UpdateBuffer2AliasVec(buffers, aliasBuffers, hasCond);

  // e.g. if A alias B, C alias D, now update:
  // C alias A,B,D; D alias A,B,C; and update A,B's condition
  UpdateBuffer2AliasVec(aliasBuffers, buffers, hasCond);

  // AllocOp is DEFFINED, not AllocOp is UNDEFFINED.
  // The buffer of UNDEFFINED is only used as an alias buffer to
  // participate in life interval judgment.
  if (!memref_ext::isDefiningOpAllocLike(buffer)) {
    buffer2status[buffer] = BufferStatus::UNDEFFINED;
  }

  // mark the alias buffer as ignoring Inplace if it is not generated by
  // memref.alloc.
  auto it = bufferInfos.find(aliasBuffer);
  if (isIgnoreInplace && it != bufferInfos.end()) {
    it->second.ignoreInplace = true;
  }
}

std::optional<BufferCondPair *>
MemLivenessAnalysis::FindBufferCondPair(Value buffer, Value aliasValue) {
  for (BufferCondPair bufferCondPair : buffer2AliasVec[buffer]) {
    if (bufferCondPair.first == aliasValue) {
      return &bufferCondPair;
    }
  }
  return std::nullopt;
}

SmallVector<BufferCondPair>
MemLivenessAnalysis::GetAliasBufferCondPairs(Value aliasBuffer) {
  auto trueVar = buffer2AliasVec.find(aliasBuffer);
  if (trueVar != buffer2AliasVec.end()) {
    return trueVar->second;
  }
  return {};
}

SetVector<Value> MemLivenessAnalysis::GetAliasBuffers(Value aliasBuffer) {
  SetVector<Value> aliasBuffers;
  auto aliasBufferPairVec = GetAliasBufferCondPairs(aliasBuffer);
  for (auto aliasBufferPair : aliasBufferPairVec) {
    aliasBuffers.insert(aliasBufferPair.first);
  }
  return aliasBuffers;
}

void MemLivenessAnalysis::UpdateStoreOpInfo(OpInfo *opInfo,
                                            const Value storeValue,
                                            Liveness live) {
  // The src of memref store may also serve as a gen buffer.
  SmallVector<Value, 1> storeValues;
  storeValues.push_back(storeValue);
  UpdateOpGenInfo(opInfo, storeValues);
  // Collect kill buffers corresponding to operation.
  OpKillHandle(opInfo, live, opInfo->operation->getBlock());
}

void MemLivenessAnalysis::UpdateOpBufferInfo(Operation *op,
                                             const ValueRange &results) {
  for (const Value &operand : results) {
    auto it = buffer2status.find(operand);
    if (it != buffer2status.end()) {
      continue;
    }
    bufferInfos[operand] = GenerateBufferInfo(op, operand);
    buffer2status[operand] = BufferStatus::DEFFINED;
  }
}

void MemLivenessAnalysis::UpdateOpGenInfo(OpInfo *opInfo,
                                          const ValueRange &results) {
  if (results.empty()) {
    return;
  }
  for (Value operand : results) {
    auto aliasBuffers = GetAliasBuffers(operand);
    aliasBuffers.insert(operand);
    for (auto buffer : aliasBuffers) {
      UpdateOperandGenInfo(opInfo, buffer);
    }
  }
}

void MemLivenessAnalysis::UpdateOperandGenInfo(OpInfo *opInfo, Value operand) {
  auto iter_buffer = buffer2status.find(operand);
  if (iter_buffer == buffer2status.end())
    return;
  if (iter_buffer->second == BufferStatus::DEFFINED) {
    genKillMap[opInfo].gen.push_back(operand);
    buffer2status[iter_buffer->first] = BufferStatus::GENED;
  } else if (iter_buffer->second == BufferStatus::KILLED) {
    llvm_unreachable("The buffer memory has been released and cannot be used "
                     "again! ");
  }
}

void MemLivenessAnalysis::OpKillHandle(OpInfo *opInfo, Liveness live,
                                       Block *block) {
  const auto *liveBlockInfo = live.getLiveness(block);
  auto currentLiveValues =
      liveBlockInfo->currentlyLiveValues(opInfo->operation);
  if (currentLiveValues.empty()) {
    return;
  }
  SetVector<Value> liveValues(currentLiveValues.begin(),
                              currentLiveValues.end());
  for (const Value &operand : liveValues) {
    UpdateOpKillInfo(opInfo, operand, live);
  }
}

void MemLivenessAnalysis::UpdateOpKillInfo(OpInfo *opInfo, Value operand,
                                           Liveness live) {
  auto aliasBuffers = GetAliasBuffers(operand);
  aliasBuffers.insert(operand);
  for (Value aliasBuffer : aliasBuffers) {
    auto iterBuffer = buffer2status.find(aliasBuffer);
    if (iterBuffer == buffer2status.end())
      return;
    if (iterBuffer->second == BufferStatus::GENED &&
        isParentOpDominate(iterBuffer->first.getDefiningOp(),
                           opInfo->operation) &&
        AllDeadAfter(opInfo->operation, aliasBuffers, live)) {
      genKillMap[opInfo].kill.push_back(aliasBuffer);
      buffer2status[iterBuffer->first] = BufferStatus::KILLED;
    }
  }
}

bool MemLivenessAnalysis::isParentOpDominate(Operation *op1,
                                             Operation *op2) const {
  assert((op1 != nullptr && op2 != nullptr) && "op must not be nullptr");
  return op2->getParentOp()->isAncestor(op1->getParentOp());
}

bool MemLivenessAnalysis::IsBlockAfter(Block *afterBlock,
                                       Block *beforeBlock) const {
  if (afterBlock == beforeBlock) {
    return false;
  }
  assert(afterBlock != nullptr && beforeBlock != nullptr);
  mlir::Region *region = beforeBlock->getParent();
  assert(region != nullptr);
  for (auto it = region->begin(); it != region->end(); ++it) {
    if (&*it == beforeBlock) {
      for (++it; it != region->end(); ++it) {
        if (&*it == afterBlock) {
          return true;
        }
      }
      break;
    }
  }
  return false;
}

bool MemLivenessAnalysis::IsDeadAfterBlock(Value value, Block *block) const {
  for (auto &useOperand : value.getUses()) {
    Operation *useOp = useOperand.getOwner();
    assert(useOp != nullptr);
    Block *useBlock = useOp->getBlock();
    if (useBlock != block && IsBlockAfter(useBlock, block)) {
      return false;
    }
  }
  return true;
}

bool MemLivenessAnalysis::AllDeadAfter(Operation *op, SetVector<Value> aliasVec,
                                       Liveness live) const {
  for (auto aliasBuffer : aliasVec) {
    if (!live.isDeadAfter(aliasBuffer, op) ||
        !IsDeadAfterBlock(aliasBuffer, op->getBlock())) {
      return false;
    }
  }
  return true;
}

BufferInfo MemLivenessAnalysis::GenerateBufferInfo(Operation *op,
                                                   Value operand) {
  auto memorySpaceAttr = GetBufferSpaceAttr(operand);
  if (isLocalMemPlan() && isLocalBuffer(memorySpaceAttr)) {
    assert(memorySpaceAttr.has_value() && "buffer must has space!");
    return GetBufferInfo(op, operand,
                         memorySpaceAttr.value().getAddressSpace());
  } else if (isGlobalWorkSpaceMemPlan() &&
             isa<bishengir::memref_ext::AllocWorkspaceOp>(
                 operand.getDefiningOp())) {
    return GetBufferInfo(op, operand, hivm::AddressSpace::GM);
  }
  llvm_unreachable("buffer must has BufferInfo !");
}

BufferInfo MemLivenessAnalysis::GetBufferInfo(Operation *op, Value operand,
                                              hivm::AddressSpace bufferScope) {
  BufferInfo bufferInfo;
  bufferInfo.operation = op;
  bufferInfo.bufferScope = bufferScope;
  // get buffer size, now for static shape
  Value traceValue = utils::tracebackMemRef(operand);
  auto memRefType = cast<MemRefType>(traceValue.getType());
  bufferInfo.bufferType = memRefType.getElementType();
  std::optional<int64_t> totalStaticSize =
      utils::getStaticTotalSize(memRefType.getShape());
  if (!totalStaticSize.has_value()) {
    llvm_unreachable(
        "Failed to obtain op buffer shape size which should be static.");
  }
  bufferInfo.constBits =
      totalStaticSize.value() *
      static_cast<int64_t>(memRefType.getElementTypeBitWidth());
  return bufferInfo;
}

void MemLivenessAnalysis::InitializeInplacePairList() {
  for (auto &bufferInfo : bufferInfos) {
    assert(memref_ext::isDefiningOpAllocLike(bufferInfo.first));
    auto iter = buffer2AliasVec.find(bufferInfo.first);
    if (iter == buffer2AliasVec.end()) {
      continue;
    }
    for (auto &aliasBufferPair : iter->second) {
      auto &aliasBuffer = aliasBufferPair.first;
      if (!memref_ext::isDefiningOpAllocLike(aliasBuffer)) {
        continue;
      }
      if (count(inplacePairList.begin(), inplacePairList.end(),
                std::make_pair(bufferInfo.first, aliasBuffer)) ||
          count(inplacePairList.begin(), inplacePairList.end(),
                std::make_pair(aliasBuffer, bufferInfo.first))) {
        continue;
      }
      if (aliasBufferPair.second) {
        continue;
      }
      inplacePairList.emplace_back(
          std::make_pair(bufferInfo.first, aliasBuffer));
      auto it = bufferInfos.find(aliasBuffer);
      assert(it != bufferInfos.end() && "buffer Info need define before! ");
      it->second.ignoreInplace = true;
      bufferInfo.second.ignoreInplace = true;
    }
  }
}

void MemLivenessAnalysis::GenerateBufferLife() {
  int scopeTime = 0;
  for (size_t i = 0; i < linearOperation.size(); ++i) {
    auto it = genKillMap.find(linearOperation[i].get());
    if (it == genKillMap.end()) {
      scopeTime++;
      continue;
    }
    // Time given to buffer start.
    for (const Value &genBuffer : it->second.gen) {
      std::unique_ptr<BufferLife> bufferLife =
          std::make_unique<BufferLife>(genBuffer);
      bufferLife->allocTime = scopeTime;
      buffer2Life[genBuffer] = std::move(bufferLife);
    }
    // Time given to buffer end.
    for (const Value &killBuffer : it->second.kill) {
      auto iter = buffer2Life.find(killBuffer);
      assert(iter != buffer2Life.end() &&
             "buffer has not been generated before! ");
      iter->second->freeTime = scopeTime;
    }
    scopeTime++;
  }
}

std::shared_ptr<BufferLife>
StorageEntry::GetBufferLifeByValue(const Value v) const {
  auto find = std::find_if(
      bufferLifeVec.begin(), bufferLifeVec.end(),
      [v](std::shared_ptr<BufferLife> life) { return life->buffer == v; });
  if (find != bufferLifeVec.end()) {
    return *find;
  }
  return nullptr;
}

bool MemPlan::IsReuseHIVMOp(Operation *op, const Value &genBuffer,
                            const Value &killBuffer) const {
  if (mlir::isa<hivm::VAddOp, hivm::VSubOp, hivm::VMaxOp, hivm::VMinOp,
                hivm::VOrOp, hivm::VAndOp, hivm::VMulOp>(op) &&
      !hasInlineBroadcastOrTransposeAttr(op)) {
    // above ops can be inplaced in isa
    return true;
  }

  if (restrictInplaceAsISA)
    return false;

  // not in ISA but confirmed with hardware developers:
  // elementwise ops with the same shape and the same bitwidth operands can also
  // do memory inplace for src and dst

  if (!isReusableExtraBuffer(op))
    return false;

  auto hivmOp = dyn_cast<hivm::HIVMStructuredOp>(op);
  if (!hivmOp || !mlir::hivm::detail::isElemwiseNaryOpImpl(op))
    return false;

  // Tranpose operands can not inplace even with same shape
  if (hivmOp.existInlineTransposeLoopDims())
    return false;

  if (hivmOp.existInlineBroadcastLoopDims())
    return isStaticShapeSame(op, genBuffer, killBuffer);

  if (op->hasTrait<OpTrait::SameOperandsElementType>())
    return true;

  return isReusableOperands(op, hivmOp);
}

SmallVector<ValuePair> MemPlan::GenerateInplaceList() {
  SmallVector<ValuePair> inplaceList;
  DenseMap<Operation *, bool> hasTouchOp;
  inplaceList.insert(inplaceList.end(), inplacePairList.begin(),
                     inplacePairList.end());
  for (auto &operationSeq : linearOperation) {
    auto it = genKillMap.find(operationSeq.get());
    if (it == genKillMap.end())
      continue;
    if (hasTouchOp[operationSeq->operation]) {
      continue;
    }
    for (const Value &genBuffer : it->second.gen) {
      auto genBufferIter = bufferInfos.find(genBuffer);
      assert(genBufferIter != bufferInfos.end() &&
             "genBuffer should be find in bufferInfos");
      if (genBufferIter->second.ignoreInplace) {
        continue;
      }
      for (const Value &killBuffer : it->second.kill) {
        auto killBufferIter = bufferInfos.find(killBuffer);
        assert(killBufferIter != bufferInfos.end() &&
               "killBuffer should be find in bufferInfos");
        if (killBufferIter->second.ignoreInplace) {
          continue;
        }
        bool canInplace =
            killBufferIter->second.constBits >=
                genBufferIter->second.constBits &&
            IsReuseHIVMOp(it->first->operation, genBuffer, killBuffer);
        if (canInplace) {
          inplaceList.emplace_back(std::make_pair(genBuffer, killBuffer));
          break;
        }
      }
    }
    // Nodes in inplace are only processed once.
    hasTouchOp[operationSeq->operation] = true;
  }
  return inplaceList;
}

void MemPlan::EmitPlanMemoryFailureInfo() {
  if (failApplyBufferInfo.empty())
    return;
  for (auto &iter : failApplyBufferInfo) {
    AddressSpace space = iter.first;
    func_.emitError() << stringifyEnum(space) << " overflow, requires "
                      << iter.second << " bits while "
                      << GetBufferSpaceInfo(space).second << " bits available! "
                      << "(possible reason: tiling basic block is too large "
                      << "or block number is more than what user expect due to multi-buffer feature is enabled "
                      << "and some ops need extra local buffer.)";
  }
}

// Plan Memory algorithm.
LogicalResult MemPlan::plan() {
  // Construct StorageEntry structure.
  GenerateStorageEntry();
  // Plan memory address.
  PlanStatus as = planMode == MemPlanMode::LOCAL_MEM_PLAN
                      ? PlanLocalMemAddress()
                      : PlanWorkSpaceMemAddress();
  if (as == PlanStatus::PLAN_FAILED) {
    EmitPlanMemoryFailureInfo();
    return failure();
  }
  // Update the address information of each buffer after memory buffer.
  UpdateBuffer2Offsets();
  return success();
}

void MemPlan::GenerateStorageEntry() {
  // create new storage entry.
  for (auto &operation : linearOperation) {
    auto it = genKillMap.find(operation.get());
    if (it == genKillMap.end())
      continue;
    for (const Value &genBuffer : it->second.gen) {
      auto iter = bufferInfos.find(genBuffer);
      if (iter == bufferInfos.end()) {
        continue;
      }
      const std::shared_ptr<BufferLife> &bufLife = buffer2Life.at(genBuffer);
      std::unique_ptr<StorageEntry> entry = std::make_unique<StorageEntry>();
      entry->bufInfo = &iter->second;
      entry->bufferLifeVec.emplace_back(bufLife);
      entry->inplaceBuffers.emplace_back(iter->first);
      auto multiBuffer = buffer2MultiNum.find(genBuffer);
      if (multiBuffer != buffer2MultiNum.end()) {
        entry->multiBufferNum = multiBuffer->second;
      }
      buffer2storageEntry[genBuffer] = entry.get();
      // Verify the validity of parameters after initialization.
      ValidateParameters(entry);
      StorageEntryVec.emplace_back(std::move(entry));
    }
  }
}

void MemPlan::ValidateParameters(std::unique_ptr<StorageEntry> &e) const {
  assert(e->bufInfo->operation && "Unrecognized legal define operation !");
  assert(e->bufInfo->constBits >= 0U && "recognized illegal memory sizes !");
  assert(!e->bufferLifeVec.empty() && "Unrecognized buffer's life time !");
}

void MemPlan::UpdateBuffer2Offsets() {
  for (auto &e : StorageEntryVec) {
    for (Value &buffer : e->inplaceBuffers) {
      // MultiBuffer can cause multiple addrs.
      buffer2Offsets[buffer].push_back(
          (e->bitsOffset + utils::kBitsToByte - 1) / utils::kBitsToByte);
    }
  }
  // In the MultiBuffer scenario, single reuse db will result in additional
  // storageEntry.
  UpdateMultiBufferReuseExtraOffset();
}

void MemPlan::UpdateMultiBufferReuseExtraOffset() {
  if (pingEntry2RelationPongEntry.empty()) {
    return;
  }

  for (auto &relationEntry : pingEntry2RelationPongEntry) {
    for (Value &buffer : relationEntry.second->inplaceBuffers) {
      // MultiBuffer can cause multiple addrs.
      buffer2Offsets[buffer].push_back(
          (relationEntry.second->bitsOffset + utils::kBitsToByte - 1) /
          utils::kBitsToByte);
    }
  }
}

void MemPlan::MergeInplaceSE() {
  // get the list of inplace value pair.
  SmallVector<ValuePair> inplaceList = GenerateInplaceList();
  // try to merge storage entries. genSE is replaced by KillSE.
  for (const auto &pairIter : inplaceList) {
    const StorageEntry *genSE = buffer2storageEntry[pairIter.first];
    StorageEntry *killSE = buffer2storageEntry[pairIter.second];
    if (genSE == killSE) {
      // already same storageEntry, no need to inplace.
      continue;
    }
    assert(genSE != nullptr && killSE != nullptr &&
           " genSE and killSE should be valid");
    BufferLifeVec mergedBufferLifeVec;
    mergedBufferLifeVec.insert(mergedBufferLifeVec.end(),
                               genSE->bufferLifeVec.begin(),
                               genSE->bufferLifeVec.end());
    mergedBufferLifeVec.insert(mergedBufferLifeVec.end(),
                               killSE->bufferLifeVec.begin(),
                               killSE->bufferLifeVec.end());
    MergeBufferVec(mergedBufferLifeVec);
    killSE->bufferLifeVec.swap(mergedBufferLifeVec);
    //  merge allocs of two storage entry and update inplaceBuffers.
    killSE->inplaceBuffers.insert(killSE->inplaceBuffers.begin(),
                                  genSE->inplaceBuffers.begin(),
                                  genSE->inplaceBuffers.end());

    // Take the maximum value for the inplace scene.
    killSE->multiBufferNum =
        std::max(genSE->multiBufferNum, killSE->multiBufferNum);

    // all buffers have same storage entry after merging
    for (auto &buffer : genSE->inplaceBuffers) {
      buffer2storageEntry[buffer] = killSE;
    }
    // remove the alloc info of dst after successful merging
    auto e = std::find_if(StorageEntryVec.begin(), StorageEntryVec.end(),
                          [genSE](std::unique_ptr<StorageEntry> &se) {
                            return se.get() == genSE;
                          });
    StorageEntryVec.erase(e);
  }
}

PlanStatus MemPlan::PlanLocalMemAddress() {
  // merge from the first storage entry
  MergeInplaceSE();
  dmaFirstPipelineOpt.build(func_);
  ExpandMultiBufferStorageEntry();
  MergeSameScopeSE();
  return PlanMemAddressOfWholeLocalBuffer();
}

PlanStatus MemPlan::PlanWorkSpaceMemAddress() {
  // merge from the first storage entry
  MergeInplaceSE();
  ExpandMultiBufferStorageEntry();
  // Construct root StorageEntry and collect the same work space arg
  // StorageEntry.
  MergeSameWorkSpaceArgSE();
  // Start plan
  return PlanMemOffsetOfWholeWorkSpace();
}

void MemPlan::MergeSameWorkSpaceArgSE() {
  for (auto &iter : StorageEntryVec) {
    auto allocWorkspaceOp = dyn_cast<bishengir::memref_ext::AllocWorkspaceOp>(
        iter->bufInfo->operation);
    assert(allocWorkspaceOp && "only allocWorkspaceOp will plan");
    Value workspaceArg = allocWorkspaceOp.getWorkspaceArg();
    auto iter_arg = workSpaceArg2rootStorageEntry.find(workspaceArg);
    if (iter_arg == workSpaceArg2rootStorageEntry.end()) {
      workSpaceArg2rootStorageEntry[workspaceArg] = iter.get();
    } else {
      iter_arg->second->mergedChildren.push_back(iter.get());
    }
  }
}

PlanStatus MemPlan::PlanMemOffsetOfWholeWorkSpace() {
  for (auto &it : workSpaceArg2rootStorageEntry) {
    StorageEntry *rootStorageEntry = it.second;
    assert(rootStorageEntry && "Root StorageEntry should not be null");
    if (!enableGlobalReuse) {
      GlobalWorkspaceNoReuse(rootStorageEntry);
      continue;
    }
    MemBoundList outline;
    PlanRecHis history;
    SpecInfo si;
    // Can be reuse without conflicting life intervals.
    si.specLevel = si.minLevel;
    int childrenNum = static_cast<int>(rootStorageEntry->mergedChildren.size());
    outline.push_back(std::make_shared<MemoryBound>(
        BufferLifeVec(), 0, std::numeric_limits<uint64_t>::max(), nullptr));

    // The initial value is rootStorageEntry.
    StorageEntry *curEntry = rootStorageEntry;
    while (si.childIdx < childrenNum) {
      curEntry->alignedConstBits =
          static_cast<uint64_t>(curEntry->bufInfo->constBits);
      curEntry->childIdx = si.childIdx;
      LogicalResult planResult = MultiSpecPlan(si, outline, history, curEntry);
      if (failed(planResult)) {
        return PlanStatus::PLAN_FAILED;
      }
      if (si.childIdx >= childrenNum) {
        break;
      }
      curEntry = rootStorageEntry->mergedChildren[si.childIdx];
    }
  }
  planStatus = PlanStatus::PLAN_SUCCESS;
  return planStatus;
}

void MemPlan::GlobalWorkspaceNoReuse(StorageEntry *rootStorageEntry) {
  rootStorageEntry->bitsOffset = 0;
  uint64_t offset = static_cast<uint64_t>(rootStorageEntry->bufInfo->constBits);
  for (StorageEntry *child : rootStorageEntry->mergedChildren) {
    child->bitsOffset = offset;
    offset += static_cast<uint64_t>(child->bufInfo->constBits);
  }
}

void MemPlan::ExpandMultiBufferStorageEntry() {
  // StorageEntry that needs to be expanded.
  size_t size = StorageEntryVec.size();
  for (size_t i = 0; i < size; i++) {
    if (StorageEntryVec[i]->multiBufferNum > 1) {
      std::unique_ptr<StorageEntry> entry = std::make_unique<StorageEntry>();
      entry->bufInfo = StorageEntryVec[i]->bufInfo;
      entry->bufferLifeVec = StorageEntryVec[i]->bufferLifeVec;
      entry->alignedConstBits = StorageEntryVec[i]->alignedConstBits;
      entry->inplaceBuffers = StorageEntryVec[i]->inplaceBuffers;
      entry->multiBufferNum = StorageEntryVec[i]->multiBufferNum;
      // Ping saves information related to Pong.
      StorageEntryVec[i]->relationPongEntry = entry.get();
      StorageEntryVec.push_back(std::move(entry));
    }
  }
}

bool MemPlan::IsEnoughForBuffersNoReuse(StorageEntry *rootStorageEntry,
                                        size_t restBufferSize,
                                        size_t alignUnit) {
  auto iter =
      bufferScope2RequiredSize.find(rootStorageEntry->bufInfo->bufferScope);
  assert(iter != bufferScope2RequiredSize.end());
  if (iter->second < restBufferSize) {
    PlanBuffersWithoutReuse(rootStorageEntry, alignUnit);
    return true;
  }
  return false;
}

void MemPlan::PlanBuffersWithoutReuse(StorageEntry *rootStorageEntry,
                                      size_t alignUnit) {
  uint offset = 0;
  rootStorageEntry->bitsOffset = offset;
  offset = AlignUp(rootStorageEntry->bufInfo->constBits, alignUnit);
  for (StorageEntry *child : rootStorageEntry->mergedChildren) {
    child->bitsOffset = offset;
    offset += AlignUp(child->bufInfo->constBits, alignUnit);
  }
}

void MemPlan::MergeSameScopeSE() {
  // Construct root StorageEntry and collect the same scope StorageEntry
  for (auto &iter : StorageEntryVec) {
    auto iter_scope =
        memscope2rootStorageEntry.find(iter->bufInfo->bufferScope);
    if (iter_scope == memscope2rootStorageEntry.end()) {
      memscope2rootStorageEntry[iter->bufInfo->bufferScope] = iter.get();
    } else {
      iter_scope->second->mergedChildren.push_back(iter.get());
    }
  }

  // set bufferScope2RequiredSize for all StorageEntry
  for (auto &rootStorageEntry : memscope2rootStorageEntry) {
    auto bufferSpaceInfo = GetBufferSpaceInfo(rootStorageEntry.first);
    size_t accumulateSize = AlignUp(rootStorageEntry.second->bufInfo->constBits,
                                    bufferSpaceInfo.first);
    for (auto &childrenStorageEntry : rootStorageEntry.second->mergedChildren) {
      size_t curStorageSize = AlignUp(childrenStorageEntry->bufInfo->constBits,
                                      bufferSpaceInfo.first);
      accumulateSize = accumulateSize + curStorageSize;
    }
    bufferScope2RequiredSize[rootStorageEntry.first] = accumulateSize;
  }
}

void mlir::hivm::MemPlan::PlanMemAddressForLevel0(
    StorageEntry *rootStorageEntry) {
  // get the buffer info for a given scope.
  auto bufferSpaceInfo =
      GetBufferSpaceInfo(rootStorageEntry->bufInfo->bufferScope);
  size_t align = bufferSpaceInfo.first;
  size_t maxBits = UINT64_MAX;
  rootStorageEntry = GetReorderRootStorageEntry(rootStorageEntry);
  // memory outline in a given buffer scope.
  MemBoundList outline;
  PlanRecHis history;
  SpecInfo si;
  si.specLevel = SPEC_LEVEL_0;
  si.maxLevel = SPEC_LEVEL_0;
  int childrenNum = static_cast<int>(rootStorageEntry->mergedChildren.size());
  outline.push_back(
      std::make_shared<MemoryBound>(BufferLifeVec(), 0, maxBits, nullptr));

  // The initial value is rootStorageEntry.
  StorageEntry *curEntry = rootStorageEntry;
  while (si.childIdx < childrenNum) {
    uint64_t needBits = static_cast<uint64_t>(curEntry->bufInfo->constBits);
    curEntry->alignedConstBits = AlignUp(needBits, align);
    curEntry->childIdx = si.childIdx;
    (void)MultiSpecPlan(si, outline, history, curEntry);
    if (si.childIdx >= childrenNum) {
      break;
    }
    curEntry = rootStorageEntry->mergedChildren[si.childIdx];
  }
  // Find the max appled bits from all children and root, which is the max
  // memory applied in this buffer space.
  uint64_t maxAllocBits = rootStorageEntry->alignedConstBits;
  auto children = rootStorageEntry->mergedChildren;
  for (auto *child : children) {
    maxAllocBits =
        std::max(maxAllocBits, child->bitsOffset + child->alignedConstBits);
  }
  failApplyBufferInfo[rootStorageEntry->bufInfo->bufferScope] = maxAllocBits;
}

PlanStatus MemPlan::PlanMemAddressOfWholeLocalBuffer() {
  // Start plan
  for (auto &it : memscope2rootStorageEntry) {
    StorageEntry *rootStorageEntry = it.second;
    assert(rootStorageEntry && "Root StorageEntry should not be null");
    // get the buffer info for a given scope.
    auto bufferSpaceInfo =
        GetBufferSpaceInfo(rootStorageEntry->bufInfo->bufferScope);
    size_t align = bufferSpaceInfo.first;
    size_t maxBits = bufferSpaceInfo.second;
    if (rootStorageEntry->mergedChildren.empty()) {
      // Only one buffer needs to be allocated within the same scope, allocate
      // directly.
      uint64_t needAlignedBits =
          AlignUp(rootStorageEntry->bufInfo->constBits, align);
      if (needAlignedBits > maxBits) {
        failApplyBufferInfo[rootStorageEntry->bufInfo->bufferScope] =
            needAlignedBits;
        return PlanStatus::PLAN_FAILED;
      }
      rootStorageEntry->bitsOffset = 0;
      continue;
    }
    if (IsEnoughForBuffersNoReuse(rootStorageEntry, maxBits, align)) {
      continue;
    }
    rootStorageEntry = GetReorderRootStorageEntry(rootStorageEntry);
    ReportMemLifeDebugInfo(rootStorageEntry);
    // memory outline in a given buffer scope.
    MemBoundList outline;
    PlanRecHis history;
    SpecInfo si;
    si.specLevel = si.maxLevel;
    int childrenNum = static_cast<int>(rootStorageEntry->mergedChildren.size());
    outline.push_back(
        std::make_shared<MemoryBound>(BufferLifeVec(), 0, maxBits, nullptr));

    // The initial value is rootStorageEntry.
    StorageEntry *curEntry = rootStorageEntry;
    while (si.childIdx < childrenNum) {
      uint64_t needBits = static_cast<uint64_t>(curEntry->bufInfo->constBits);
      curEntry->alignedConstBits = AlignUp(needBits, align);
      curEntry->childIdx = si.childIdx;
      LDBG("\n");
      LDBG("----------Need-Plan-CurEntry---------\n");
      ReportCurEntryDebugInfo(curEntry);
      LDBG("\n");
      LogicalResult planResult = MultiSpecPlan(si, outline, history, curEntry);
      if (failed(planResult)) {
        StatusWrapper statusWrapper = {false,   curEntry->alignedConstBits,
                                       &si,     outline,
                                       history, rootStorageEntry};
        LDBG("\n");
        LDBG("----------ApplyFailStrategy---------\n");
        ReportCurEntryDebugInfo(curEntry);
        LDBG("\n");
        PlanStatus as = ApplyFailStrategy(statusWrapper, maxBits);
        if (as == PlanStatus::RESTART_NEW_PLAN) {
          // Restart plan.
          si = SpecInfo();
          curEntry = rootStorageEntry;
          continue;
        }
        if (as == PlanStatus::PLAN_FAILED) {
          ReportAllocatedEntryDebugInfo(rootStorageEntry);
          PlanMemAddressForLevel0(rootStorageEntry);
          return as;
        }
      }
      if (si.childIdx >= childrenNum) {
        break;
      }
      curEntry = rootStorageEntry->mergedChildren[si.childIdx];
    }
  }
  planStatus = PlanStatus::PLAN_SUCCESS;
  return planStatus;
}

void MemPlan::ReportMemLifeDebugInfo(const StorageEntry *rootStorageEntry) {
  LDBG("-------------------------- Buffer2Life --------------------------\n");
  MemLifeDebugInfo(rootStorageEntry);
  for (auto &StorageEntry : rootStorageEntry->mergedChildren) {
    MemLifeDebugInfo(StorageEntry);
  }
}

void MemPlan::MemLifeDebugInfo(const StorageEntry *storageEntry) const {
  for (auto &buffer : storageEntry->inplaceBuffers) {
    if (buffer.getDefiningOp()) {
      if (auto allocOp = dyn_cast<memref::AllocOp>(buffer.getDefiningOp())) {
        LDBG("Buffer : " << allocOp.getResult() << "\n");
      }
    }
  }
#ifndef NDEBUG
  for (auto &bufferLife : storageEntry->bufferLifeVec) {
    LDBG("bufferLife : " << "allocTime : " << bufferLife->allocTime
                         << " , freeTime : " << bufferLife->freeTime << "\n");
  }
  LDBG("\n");
#endif
}

void MemPlan::ReportCurEntryDebugInfo(const StorageEntry *curEntry) const {
  for (auto &buffer : curEntry->inplaceBuffers) {
    if (buffer.getDefiningOp()) {
      if (auto allocOp = dyn_cast<memref::AllocOp>(buffer.getDefiningOp())) {
        LDBG("buffer : ");
        LDBG(allocOp.getResult());
      }
    }
  }
}

StorageEntry *
MemPlan::GetReorderRootStorageEntry(StorageEntry *rootStorageEntry) {
  if (rootStorageEntry->bufInfo->bufferScope != hivm::AddressSpace::UB) {
    return rootStorageEntry;
  }
  SmallVector<StorageEntry *> origStorageEntryVec = {rootStorageEntry};
  origStorageEntryVec.insert(origStorageEntryVec.end(),
                             rootStorageEntry->mergedChildren.begin(),
                             rootStorageEntry->mergedChildren.end());

  // reorder storage entrys: dma touched buffers + other buffers + scalar
  // touched buffers
  SmallVector<StorageEntry *> reorderedStorageEntryVec;
  SmallVector<StorageEntry *> touchPipeScalarStorageEntryVec;
  for (auto &storageEntry : origStorageEntryVec) {
    for (auto &buffer : storageEntry->inplaceBuffers) {
      if (dmaFirstPipelineOpt.IsDmaBuffer(buffer)) {
        reorderedStorageEntryVec.push_back(storageEntry);
        break;
      }
      if (dmaFirstPipelineOpt.IsScalarBuffer(buffer)) {
        touchPipeScalarStorageEntryVec.push_back(storageEntry);
        break;
      }
    }
  }
  for (auto &storageEntry : origStorageEntryVec) {
    auto it1 = std::find(reorderedStorageEntryVec.begin(),
                         reorderedStorageEntryVec.end(), storageEntry);
    auto it2 = std::find(touchPipeScalarStorageEntryVec.begin(),
                         touchPipeScalarStorageEntryVec.end(), storageEntry);
    if (it1 == reorderedStorageEntryVec.end() &&
        it2 == touchPipeScalarStorageEntryVec.end()) {
      reorderedStorageEntryVec.push_back(storageEntry);
    }
  }

  reorderedStorageEntryVec.insert(reorderedStorageEntryVec.end(),
                                  touchPipeScalarStorageEntryVec.begin(),
                                  touchPipeScalarStorageEntryVec.end());

  // Ensure that ping pong is continuously plan mem in the multi buffer.
  ReorderContinuousPingPongEntry(reorderedStorageEntryVec);
  StorageEntry *reorderedRootStorageEntry = reorderedStorageEntryVec[0];
  reorderedRootStorageEntry->mergedChildren.clear();
  for (size_t j = 1; j < reorderedStorageEntryVec.size(); ++j) {
    reorderedRootStorageEntry->mergedChildren.push_back(
        reorderedStorageEntryVec[j]);
  }
  return reorderedRootStorageEntry;
}

void MemPlan::ReorderContinuousPingPongEntry(
    SmallVector<StorageEntry *> &storageEntryVec) {
  SmallVector<StorageEntry *> reorderedStorageEntryVec;
  for (auto &storageEntry : storageEntryVec) {
    auto it = std::find(reorderedStorageEntryVec.begin(),
                        reorderedStorageEntryVec.end(), storageEntry);
    if (it == reorderedStorageEntryVec.end()) {
      reorderedStorageEntryVec.push_back(storageEntry);
      if (storageEntry->multiBufferNum == 2 &&
          storageEntry->relationPongEntry) {
        // Ping Pong continuous save.
        reorderedStorageEntryVec.push_back(storageEntry->relationPongEntry);
      }
    }
  }
  reorderedStorageEntryVec.swap(storageEntryVec);
}

std::pair<size_t, size_t>
MemPlan::GetBufferSpaceInfo(hivm::AddressSpace &space) const {
  switch (space) {
  case hivm::AddressSpace::UB:
    return std::make_pair(ubAlignSize, ubSpaceSize);
  case hivm::AddressSpace::L1:
    return std::make_pair(l1AlignSize, l1SpaceSize);
  case hivm::AddressSpace::L0C:
    return std::make_pair(l0cAlignSize, l0cSpaceSize);
  default:
    llvm_unreachable("Temporarily unsupported memory buffer space !");
  }
}

LogicalResult MemPlan::MultiSpecPlan(SpecInfo &si, MemBoundList &outline,
                                     PlanRecHis &history, StorageEntry *entry) {
  LogicalResult planResult = failure();
  for (int i = si.specLevel; i >= si.minLevel; i--) {
    planResult = SpecAlloc(outline, history, entry, si, i);
    if (succeeded(planResult)) {
      if (si.childIdx == si.specStartIdx) {
        // In roll back plan, when the specified specStartIdx is reached,
        // the subsequent plan still adopts the maxLevel strategy.
        si.specLevel = si.maxLevel;
      }
      si.childIdx++;
      break;
    }
  }
  return planResult;
}

void MemPlan::RecordAllocatedEntry(const StorageEntry *e) {
  assert(e && e->bufInfo);
  auto it = memscope2allocatedEntry.find(e->bufInfo->bufferScope);
  if (it == memscope2allocatedEntry.end()) {
    SmallVector<const StorageEntry *> allocatedEntry;
    allocatedEntry.push_back(e);
    memscope2allocatedEntry.insert(
        std::make_pair(e->bufInfo->bufferScope, allocatedEntry));
    return;
  }
  auto &allocatedEntry = it->second;
  bool needRecord = allocatedEntry.end() ==
                    std::find(allocatedEntry.begin(), allocatedEntry.end(), e);
  if (needRecord) {
    allocatedEntry.push_back(e);
  }
}

LogicalResult MemPlan::SpecAlloc(MemBoundList &outline, PlanRecHis &his,
                                 StorageEntry *e, const SpecInfo &si,
                                 int localLevel) {
  if (std::any_of(his.begin(), his.end(),
                  [e](PlanRecord &r) { return r.entry && r.entry == e; })) {
    // If the plan has already been completed, return success directly.
    return success();
  }
  assert(e && "StorageEntry should not be null");
  for (MemBoundListConstIter start = outline.begin(); start != outline.end();
       ++start) {
    uint64_t size = 0;
    uint64_t allocOffset = (*start)->offset;
    for (MemBoundListConstIter end = start; end != outline.end(); ++end) {
      std::shared_ptr<MemoryBound> last = *end;
      size += last->extent;
      // if index & addr are as same as last rollback result,
      // continue to find next result
      if (IsSamePlanAsLastRollBack(allocOffset, e->childIdx, si) ||
          VerifyConflictStage0(e, last)) {
        start = end;
        break;
      }
      if (size < e->alignedConstBits) {
        continue;
      }
      // If SPEC_LEVEL_1, then the address of pong Offset address needs to be
      // allocated.
      uint64_t pongOffset{0};
      if (localLevel == SPEC_LEVEL_1 &&
          VerifyConflictStage1(outline, his, e,
                               OutlineSectionInfo(start, end, size, false),
                               pongOffset)) {
        break;
      }

      if (VerifyConflictStage2(his, e, localLevel, start, outline)) {
        break;
      }
      e->bitsOffset = allocOffset;
      UpdateOutline(outline, his, e,
                    OutlineSectionInfo(start, end, size, false), localLevel);

      if (localLevel == SPEC_LEVEL_1) {
        // There is no conflict with the historical plan of buffer life, and
        // the address of the Pong can be assigned.
        PlanRelationPongEntryAddress(pongOffset, e);
        SpecAllocRelationPongEntry(outline, his, e, pongOffset);
      }
      LDBG("APPLY_SPEC_LEVEL:  " << localLevel << "\n");
      RecordAllocatedEntry(e);
      return success();
    }
  }
  return failure();
}

LoopLikeOpInterface
MemPlan::GetBufferParentLoop(const SmallVector<Value> &buffers) {
  llvm::SmallSet<LoopLikeOpInterface, 1> parentLoopVec;
  for (auto buffer : buffers) {
    if (!buffer.getDefiningOp()) {
      assert(isa<scf::ForOp>(buffer.getParentBlock()->getParentOp()) ||
             isa<scf::WhileOp>(buffer.getParentBlock()->getParentOp()));
      // Init args and region iter arg are inplace, ignore Region Iter Arg
      // without DefineOp.
      continue;
    }
    LoopLikeOpInterface bufferParentLoop = getParentLoop(buffer);
    if (bufferParentLoop) {
      parentLoopVec.insert(bufferParentLoop);
    } else {
      return nullptr;
    }
  }
  if (parentLoopVec.size() == 1) {
    return *parentLoopVec.begin();
  }
  return nullptr;
}

bool MemPlan::VerifyConflictStage1(MemBoundList &outline, PlanRecHis &his,
                                   StorageEntry *e,
                                   const OutlineSectionInfo &outlineInfo,
                                   uint64_t &pongOffset) {
  if (outlineInfo.mem_start != outlineInfo.mem_end) {
    return true;
  }
  auto reuseBoundStorageEntry = (*outlineInfo.mem_start)->lastStorageEntry;
  if (!reuseBoundStorageEntry) {
    // This area has not been planed, so there is no need to consider it.
    return true;
  }

  StorageEntry *multiRelationPongEntry =
      GetMultiRelationPongEntry(reuseBoundStorageEntry);
  if (multiRelationPongEntry) {
    if (e->multiBufferNum == 1 ||
        (e->multiBufferNum == 2 && e->relationPongEntry &&
         (e->relationPongEntry->bitsOffset != 0))) {
      auto parentLoop1 = GetBufferParentLoop(e->inplaceBuffers);
      auto parentLoop2 =
          GetBufferParentLoop(reuseBoundStorageEntry->inplaceBuffers);
      if (!(parentLoop1 != nullptr && parentLoop2 != nullptr &&
            parentLoop1 == parentLoop2)) {
        // Cannot be reused under the same loop.
        return true;
      }
      // There are two situations:
      // 1. Single reuse DB.
      // 2. DB reuse DB.
      pongOffset = multiRelationPongEntry->bitsOffset;
      bool conflict = std::any_of(
          his.begin(), his.end(), [pongOffset, e, this](PlanRecord &r) {
            return this->IsBufferLifeVecConflict(r, pongOffset, e);
          });
      if (!conflict) {
        return false;
      }
    }
  }
  return true;
}

StorageEntry *
MemPlan::GetMultiRelationPongEntry(const StorageEntry *reuseBoundStorageEntry) {
  if (reuseBoundStorageEntry->multiBufferNum == 2 &&
      reuseBoundStorageEntry->relationPongEntry &&
      (reuseBoundStorageEntry->relationPongEntry->bitsOffset != 0)) {
    // If the reuseBoundStorageEntry itself requires db, directly match and
    // return relationPongEntry.
    return reuseBoundStorageEntry->relationPongEntry;
  }
  auto iter = pingEntry2RelationPongEntry.find(reuseBoundStorageEntry);
  if (iter != pingEntry2RelationPongEntry.end()) {
    // If the reuseBoundStorageEntry itself is single, but has already been
    // reused with db and has an extra pong StorageEntry is added.
    return iter->second.get();
  }
  return nullptr;
}

void MemPlan::SpecAllocRelationPongEntry(MemBoundList &outline, PlanRecHis &his,
                                         StorageEntry *e, uint64_t offset) {
  for (MemBoundListConstIter start = outline.begin(); start != outline.end();
       ++start) {
    uint64_t size = 0;
    // Find the MemBound corresponding to the Pong offset.
    if ((*start)->offset != offset) {
      continue;
    }
    for (MemBoundListConstIter end = start; end != outline.end(); ++end) {
      std::shared_ptr<MemoryBound> last = *end;
      size += last->extent;
      if (size < e->alignedConstBits) {
        continue;
      }
      StorageEntry *pongStorageEntry = nullptr;
      auto iter = pingEntry2RelationPongEntry.find(e);
      if (iter != pingEntry2RelationPongEntry.end()) {
        pongStorageEntry = iter->second.get();
      }
      if (e->multiBufferNum == 2 && e->relationPongEntry) {
        pongStorageEntry = e->relationPongEntry;
      }
      assert(pongStorageEntry && "PongStorage Entry not found!");
      UpdateOutline(outline, his, pongStorageEntry,
                    OutlineSectionInfo(start, end, size, true), SPEC_LEVEL_1);
      return;
    }
  }
}

bool MemPlan::IsBufferLifeVecConflict(PlanRecord &r, uint64_t offset,
                                      const StorageEntry *e) const {
  if ((r.firstMemBound->offset + r.allExtent > offset) &&
      (r.firstMemBound->offset < offset + e->alignedConstBits)) {
    DenseMap<ValuePair, BufferLife> intersection =
        GetOverlapBufferLife(r.entry->bufferLifeVec, e->bufferLifeVec);
    return !intersection.empty();
  }
  return false;
}

void MemPlan::PlanRelationPongEntryAddress(uint64_t offset, StorageEntry *e) {
  if (e->multiBufferNum == 1) {
    std::unique_ptr<StorageEntry> entry = std::make_unique<StorageEntry>();
    entry->bufInfo = e->bufInfo;
    entry->bufferLifeVec = e->bufferLifeVec;
    entry->alignedConstBits = e->alignedConstBits;
    entry->inplaceBuffers = e->inplaceBuffers;
    entry->multiBufferNum = e->multiBufferNum;
    entry->bitsOffset = offset;
    pingEntry2RelationPongEntry[e] = std::move(entry);
  } else if (e->multiBufferNum == 2) {
    e->relationPongEntry->bitsOffset = offset;
  } else {
    llvm_unreachable("Does not support multi buffer num greater than 2 !");
  }
}

bool MemPlan::VerifyConflictStage2(PlanRecHis &his, const StorageEntry *e,
                                   int specLevel, MemBoundListConstIter &start,
                                   const MemBoundList &outline) {
  if (specLevel != SPEC_LEVEL_2) {
    return false;
  }
  bool touchMemCanUse = false;
  MemBoundListConstIter foundMem;

  for (auto iter = start; iter != outline.end(); ++iter) {
    uint64_t offset = (*iter)->offset;
    bool conflict =
        std::any_of(his.begin(), his.end(), [offset, e, this](PlanRecord &r) {
          return (r.firstMemBound->offset + r.allExtent > offset) &&
                 (r.firstMemBound->offset < offset + e->alignedConstBits) &&
                 this->PipeConflict(r.entry, e, this->pipeDmaConflictMap);
        });
    // if conflict, continue finding the first bound that has no conflict
    // if last bound do not meet the size, continue
    if (conflict ||
        ((*iter == outline.back()) && (*iter)->extent < e->alignedConstBits)) {
      continue;
    }
    touchMemCanUse = true;
    foundMem = iter;
    break;
  }

  if (touchMemCanUse) {
    bool conflict = (foundMem != start);
    start = conflict ? --foundMem : start;
    return conflict;
  }
  // if cannot find a bound that has no conflict with current entry,
  return true;
}

bool MemPlan::PipeConflict(const StorageEntry *e1, const StorageEntry *e2,
                           DenseMap<StorageEntryPair, bool> &conflictMap) {
  if (e1 == nullptr || e2 == nullptr) {
    return false;
  }
  auto sePair = std::make_pair(e1, e2);
  auto [iter, isInserted] = conflictMap.try_emplace(sePair, false);
  if (!isInserted) {
    return iter->second;
  }

  for (const Value var1 : e1->inplaceBuffers) {
    for (const Value var2 : e2->inplaceBuffers) {
      bool conflict = dmaFirstPipelineOpt.BufferPipeConflict(var1, var2);
      if (conflict) {
        iter->second = true;
        return true;
      }
    }
  }
  return false;
}

void MemPlan::UpdateOutline(MemBoundList &outline, PlanRecHis &his,
                            StorageEntry *e,
                            const OutlineSectionInfo &outlineInfo,
                            int localLevel) const {
  if (e == nullptr) {
    return;
  }
  auto start = outlineInfo.mem_start;
  MemBoundListConstIter end = outlineInfo.mem_end;
  // outline:
  // |-------start+end-------------|
  // |--head--|--split e--|--tail--|
  uint64_t res = outlineInfo.size - e->alignedConstBits;
  std::shared_ptr<MemoryBound> last = *end;
  ++end;
  std::shared_ptr<MemoryBound> bound;
  SmallVector<std::shared_ptr<MemoryBound>> splitBound;
  // split e, to get Boundbound
  if (splitOutline) {
    // add splitBound by splitting e to section
    AddMemBoundInSectionalWay(e, start, end, splitBound);
  } else {
    // origin outline
    BufferLifeVec life(e->bufferLifeVec.begin(), e->bufferLifeVec.end());
    MergeBufferLife(start, end, life);
    splitBound.emplace_back(std::make_shared<MemoryBound>(
        life, e->bitsOffset, e->alignedConstBits, e));
  }

  // insert tail memory bound
  if (res > 0) {
    bound = std::make_shared<MemoryBound>(last->bufferLifeVec,
                                          last->offset + last->extent - res,
                                          res, last->lastStorageEntry);
    end = outline.insert(end, bound);
  }
  // insert split e memory bound
  for (int i = static_cast<int>(splitBound.size()) - 1; i >= 0; --i) {
    end = outline.insert(end, splitBound[i]);
  }
  // record the current plan of first split entry in his
  his.emplace_back(PlanRecord{localLevel,
                              e->childIdx,
                              res > 0,
                              false,
                              splitBound.size(),
                              e,
                              e->alignedConstBits,
                              splitBound[0],
                              {},
                              outlineInfo.isDirectlyRollback});
  PlanRecord &r = his.back();
  r.replaced.splice(r.replaced.begin(), outline, start, end);
}

void MemPlan::AddMemBoundInSectionalWay(
    StorageEntry *e, MemBoundListConstIter start, MemBoundListConstIter end,
    SmallVector<std::shared_ptr<MemoryBound>> &splitBound) const {
  // |--outline1--|--outline2--|--outline3--|
  // |---------e------------ |
  // |--split e1 -|-split e2-|
  for (auto iter = start; iter != end; ++iter) {
    BufferLifeVec life(e->bufferLifeVec.begin(), e->bufferLifeVec.end());
    life.insert(life.end(), (*iter)->bufferLifeVec.begin(),
                (*iter)->bufferLifeVec.end());
    // merge the buffer life
    MergeBufferVec(life);
    // get the extent
    uint64_t size = 0;
    if (std::distance(start, iter) == std::distance(start, end) - 1) {
      // deal with the last split e2
      size = e->bitsOffset + e->alignedConstBits - (*iter)->offset;
    } else {
      size = (*iter)->extent;
    }
    splitBound.emplace_back(
        std::make_shared<MemoryBound>(life, (*iter)->offset, size, e));
  }
}

inline void MemPlan::MergeBufferLife(MemBoundList::const_iterator start,
                                     MemBoundList::const_iterator end,
                                     BufferLifeVec &newLife) const {
  size_t size = 0;
  for (auto it = start; it != end; ++it) {
    size += (*it)->bufferLifeVec.size();
  }
  newLife.reserve(size);
  for (auto it = start; it != end; ++it) {
    newLife.insert(newLife.end(), (*it)->bufferLifeVec.begin(),
                   (*it)->bufferLifeVec.end());
  }
  MergeBufferVec(newLife);
}

void MemPlan::MergeBufferVec(BufferLifeVec &bufferLife) const {
  if (bufferLife.empty()) {
    return;
  }
  BufferLifeVec mergedLife;
  mergedLife.reserve(bufferLife.size());
  // sort life by alloc and free time
  std::sort(bufferLife.begin(), bufferLife.end(), CompareBufferLife());
  int start = bufferLife[0]->allocTime;
  int end = bufferLife[0]->freeTime;
  auto buffer = bufferLife[0]->buffer;
  // merge life
  for (size_t i = 1; i < bufferLife.size(); i++) {
    auto &life = bufferLife[i];
    if (life->allocTime <= end + 1) {
      end = end < life->freeTime ? life->freeTime : end;
    } else {
      mergedLife.emplace_back(std::make_unique<BufferLife>(buffer, start, end));
      buffer = life->buffer;
      start = life->allocTime;
      end = life->freeTime;
    }
  }
  mergedLife.emplace_back(std::make_unique<BufferLife>(buffer, start, end));
  bufferLife.swap(mergedLife);
}

bool MemPlan::IsSamePlanAsLastRollBack(uint64_t allocOffset, int curChildIdx,
                                       const SpecInfo &si) const {
  return curChildIdx == si.rollbackIdx && allocOffset == si.rollbackAddr;
}

// spec_level == SPEC_LEVEL_0
inline bool
MemPlan::VerifyConflictStage0(StorageEntry *e,
                              const std::shared_ptr<MemoryBound> &last) {
  // level_0: offset = 0, offset means life distance
  DenseMap<ValuePair, BufferLife> intersection =
      GetOverlapBufferLife(e->bufferLifeVec, last->bufferLifeVec);
  return !intersection.empty();
}

// verify two buffer life vectors is conflict or not
// The key pair looks like the following diagram
// indicate that var1 is generated later than var2.
//                                     buffer2
//                               --- PLAN_time
//   buffer1       intersected   | |
//                 buffer_life   | |
// PLAN_time ---    lo ---      ---
//            | |       ---      | |
//            | |       ---      | |
//            ---    hi ---      --- free_time
//            | |
//            | |
// free_time  ---
// Meantime, the overlap is the intersected buffer_life.
DenseMap<ValuePair, BufferLife>
MemPlan::GetOverlapBufferLife(const BufferLifeVec &b1,
                              const BufferLifeVec &b2) const {
  DenseMap<ValuePair, BufferLife> intersection;
  size_t i = 0;
  size_t j = 0;
  size_t b1Len = b1.size();
  size_t b2Len = b2.size();
  if (b1Len == 0 || b2Len == 0) {
    return intersection;
  }
  while (i < b1Len && j < b2Len) {
    auto lo = std::max(b1[i]->allocTime, b2[j]->allocTime);
    auto hi = std::min(b1[i]->freeTime, b2[j]->freeTime);
    if (lo <= hi) {
      BufferLife bufferLife(nullptr, lo, hi);
      ValuePair key =
          lo == b1[i]->allocTime && hi == b2[j]->freeTime
              ? std::make_pair(b1[i]->buffer,
                               b2[j]->buffer) // case in the diagram
              : std::make_pair(b2[j]->buffer, b1[i]->buffer); // opposing case
      intersection.try_emplace(key, bufferLife);
    }
    if (b1[i]->freeTime < b2[j]->freeTime) {
      i += 1;
    } else {
      j += 1;
    }
  }
  return intersection;
}

PlanStatus MemPlan::ApplyFailStrategy(StatusWrapper &statusWrapper,
                                      const size_t maxBits) {
  RollBackForAllocFail(statusWrapper, maxBits);
  // second class rollback, level 1 --> 0
  if (statusWrapper.si->specLevel > SPEC_LEVEL_0 &&
      statusWrapper.si->childIdx >= 0) {
    statusWrapper.si->specLevel--;
    return PlanStatus::CONTINUE_PLAN;
  }
  if (!splitOutline) {
    // roll back to origin again, enable split outline.
    splitOutline = true;
    return PlanStatus::RESTART_NEW_PLAN;
  }
  return PlanStatus::PLAN_FAILED;
}

void MemPlan::ReportAllocatedEntryDebugInfo(
    const StorageEntry *rootStorageEntry) const {
#ifndef NDEBUG
  auto printRecord = [this](const StorageEntry *entry) {
    uint64_t needByte =
        (entry->alignedConstBits + utils::kBitsToByte - 1) / utils::kBitsToByte;
    uint64_t offsetByte =
        (entry->bitsOffset + utils::kBitsToByte - 1) / utils::kBitsToByte;
    ReportCurEntryDebugInfo(entry);
    LDBG(", offset: " << offsetByte);
    LDBG(", extent: " << needByte);
    LDBG(", buffer life: ");
    for (auto &bufferLife : entry->bufferLifeVec) {
      LDBG("[" << bufferLife->allocTime << "-" << bufferLife->freeTime
               << "], ");
    }
  };
  LDBG("--------------------------BUFFER ALLOCATE "
       "START-------------------------- "
       << "\n"
       << "\n");
  LDBG("  BUFFER ALLOCATE START: " << rootStorageEntry->bufInfo->bufferScope
                                   << "\n");

  auto it =
      memscope2allocatedEntry.find(rootStorageEntry->bufInfo->bufferScope);
  if (it != memscope2allocatedEntry.end() && !it->second.empty()) {
    const auto &allocatedEntry = it->second;
    for (auto &entry : allocatedEntry) {
      printRecord(entry);
      LDBG("\n");
    }
    size_t num = allocatedEntry.size() - 1;
    assert(rootStorageEntry->mergedChildren.size() > num);
    const StorageEntry *failedSe = rootStorageEntry->mergedChildren[num];
    printRecord(failedSe);
  }
  LDBG("alloc fail,because exceed bound of memory "
       << rootStorageEntry->bufInfo->bufferScope << "\n");
  LDBG("  BUFFER ALLOCATE END \n");
  LDBG("\n"
       << "--------------------------BUFFER ALLOCATE "
          "END-------------------------- "
       << "\n");
#endif
}

void MemPlan::RollBackForAllocFail(StatusWrapper &statusWrapper,
                                   const size_t maxBits) {
  while (ContinueRollBack(statusWrapper)) {
    RollBackForAllocFailInner(statusWrapper, maxBits);
  }
}

bool MemPlan::ContinueRollBack(const StatusWrapper &statusWrapper) const {
  return (!statusWrapper.hasEnoughRollBackSize) &&
         (!statusWrapper.history.empty() && (!statusWrapper.outline.empty()));
}

void MemPlan::RollBackForAllocFailInner(StatusWrapper &statusWrapper,
                                        const size_t maxBits) {
  auto &si = statusWrapper.si;
  if (si->childIdx > si->specStartIdx) {
    si->specStartIdx = si->childIdx;
  }
  // Check whether the container is empty before accessing "history"
  while (!statusWrapper.history.empty()) {
    PlanRecord r =
        RollbackOutline(statusWrapper.history, statusWrapper.outline);
    auto iter = pingEntry2RelationPongEntry.find(r.entry);
    if (iter != pingEntry2RelationPongEntry.end()) {
      pingEntry2RelationPongEntry.erase(iter);
    }
    if (r.isDirectlyRollback ||
        (r.entry->multiBufferNum == 2 && !r.entry->relationPongEntry)) {
      continue;
    }
    si->childIdx = r.childIdx;
    si->specLevel = r.specLevel;
    if (si->specLevel > si->minLevel) {
      // record rollback info: index and address
      si->rollbackAddr =
          si->childIdx == -1
              ? UINT64_MAX
              : statusWrapper.RootE->mergedChildren[si->childIdx]->bitsOffset;
      si->rollbackIdx = si->childIdx;
      if (statusWrapper.si->rollbackAddr + statusWrapper.alignedConstBits >
          maxBits) {
        continue;
      }
      statusWrapper.hasEnoughRollBackSize = true;
      break;
    }
  }
}

PlanRecord MemPlan::RollbackOutline(PlanRecHis &history,
                                    MemBoundList &outline) const {
  auto r = history.back();
  auto it = std::find(outline.begin(), outline.end(), r.firstMemBound);
  // |--head--|--split entry--|--tail--|
  // erase head
  if (r.headed) {
    it--;
    it = outline.erase(it);
  }
  // erase split entry
  for (size_t i = 0; i < r.splitNums; i++) {
    it = outline.erase(it);
  }
  // erase tail
  if (r.tailed) {
    it = outline.erase(it);
  }
  // restore outline and replaced
  outline.splice(it, r.replaced);
  history.pop_back();
  return r;
}

namespace {
struct PlanMemoryPass : public impl::PlanMemoryBase<PlanMemoryPass> {
public:
  explicit PlanMemoryPass(const PlanMemoryOptions &planMemoryOption)
      : PlanMemoryBase(planMemoryOption) {}

  void runOnOperation() override;

private:
  void populateBufferAddressToAllocOp(
      RewritePatternSet &patterns,
      DenseMap<Value, SmallVector<uint64_t>> buffer2Offsets) {
    if (this->memMode == MemPlanMode::LOCAL_MEM_PLAN) {
      patterns.add<MemrefAllocaOpToPointerCastOpPattern>(patterns.getContext(),
                                                         buffer2Offsets);
    } else {
      assert(this->memMode == MemPlanMode::GLOBAL_WORKSPACE_PLAN);
      patterns.add<UpdateWorkSpaceAllocaOpOffsetPattern>(patterns.getContext(),
                                                         buffer2Offsets);
    }
  }
  bool checkSimilarPointerCastOps(hivm::PointerCastOp op1,
                                  hivm::PointerCastOp op2) const;
  LogicalResult fixMultibufferEnabledPointerCastOps(Operation *funcOp) const;
};
} // namespace

bool PlanMemoryPass::checkSimilarPointerCastOps(hivm::PointerCastOp op1,
                                                hivm::PointerCastOp op2) const {
  return (op1->getNumResults() == op2->getNumResults() &&
          llvm::equal(op1->getResultTypes(), op2->getResultTypes())) &&
         (op1->getNumOperands() == op2->getNumOperands() &&
          llvm::equal(op1->getOperands(), op2->getOperands()));
}

LogicalResult
PlanMemoryPass::fixMultibufferEnabledPointerCastOps(Operation *funcOp) const {
  DenseMap<PointerCastOp, annotation::MarkOp> markedOps;

  funcOp->walk<WalkOrder::PreOrder>([&](annotation::MarkOp markOp) {
    auto attrDict = markOp->getAttrDictionary();
    if (attrDict.contains(hivm::MultiBufferAttr::name)) {
      if (auto pointerCastOp =
              dyn_cast<hivm::PointerCastOp>(markOp.getSrc().getDefiningOp())) {
        markedOps[pointerCastOp] = markOp;
      }
    }
  });

  for (auto [pointerCastOp, markedOp] : markedOps) {
    if (auto forOp = dyn_cast<scf::ForOp>(
            pointerCastOp->getParentOfType<scf::ForOp>())) {
      pointerCastOp->moveBefore(&forOp.getBody()->front());
      markedOp->moveAfter(pointerCastOp);
    }
  }

  std::vector<PointerCastOp> visitedOps;
  funcOp->walk<WalkOrder::PreOrder>([&](hivm::PointerCastOp pointerCastOp) {
    if (markedOps.contains(pointerCastOp)) {
      visitedOps.push_back(pointerCastOp);
    } else {
      for (auto it = visitedOps.rbegin(); it != visitedOps.rend(); it++) {
        if (!it->getOperation()->getParentOp()->isAncestor(pointerCastOp)) {
          continue;
        }
        if (checkSimilarPointerCastOps(pointerCastOp, *it)) {
          pointerCastOp.replaceAllUsesWith(it->getOperation());
          break;
        }
      }
    }
  });

  return llvm::success();
}

void PlanMemoryPass::runOnOperation() {
  auto funcOp = getOperation();
  if (hacc::utils::isHost(funcOp))
    return;

  if (this->memMode == MemPlanMode::LOCAL_MEM_PLAN) {
    RewritePatternSet normalizeLoopIterPatterns(&getContext());
    populateNormalizeLoopIneratorPattern(normalizeLoopIterPatterns);
    if (failed(applyPatternsGreedily(funcOp,
                                     std::move(normalizeLoopIterPatterns)))) {
      return signalPassFailure();
    }
  }

  MemLivenessAnalysis memLiveness(funcOp, this->memMode);
  memLiveness.build();

  MemPlan memPlan(this->memMode, this->enableGlobalReuse,
                  this->restrictInplaceAsISA);
  memPlan.func_ = funcOp;
  memPlan.SetLinearOperation(memLiveness.linearOperation);
  memPlan.SetBufferInfos(memLiveness.bufferInfos);
  memPlan.SetBuffer2Life(memLiveness.buffer2Life);
  memPlan.SetGenKillMap(memLiveness.genKillMap);
  memPlan.SetBuffer2MultiNum(memLiveness.buffer2MultiNum);
  memPlan.SetInplacePairList(memLiveness.inplacePairList);
  if (failed(memPlan.plan())) {
    return signalPassFailure();
  }

  RewritePatternSet patterns(&getContext());
  populateBufferAddressToAllocOp(patterns, memPlan.GetBuffer2Offsets());
  if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
    return signalPassFailure();
  }

  if (failed(fixMultibufferEnabledPointerCastOps(funcOp))) {
    return signalPassFailure();
  }
}

std::unique_ptr<Pass>
mlir::hivm::createPlanMemoryPass(const PlanMemoryOptions &planMemoryOption) {
  return std::make_unique<PlanMemoryPass>(planMemoryOption);
}
